package msgpack

import (
	"fmt"
	"gitee.com/ymofen/gobase"
	"io"
)

// Type-family ids accepted by MsgPackDecode.AssertType, plus the MessagePack
// marker bytes of the str and bin format families.
const (
	// String family.
	msgpack_type_str = 1
	// msgpack_str_0xD5 is not a marker byte despite its name: it is the
	// 3-bit fixstr prefix 0b101, i.e. the top three bits of 0xA0..0xBF.
	msgpack_str_0xD5 = 0xA0 >> 5
	msgpack_str_0xD9 = 0xD9 // str 8
	msgpack_str_0xDA = 0xDA // str 16
	msgpack_str_0xDB = 0xDB // str 32

	// Binary family.
	msgpack_type_bin = 2
	msgpack_bin_0xC4 = 0xC4 // bin 8
	msgpack_bin_0xC5 = 0xC5 // bin 16
	msgpack_bin_0xC6 = 0xC6 // bin 32
)

// WriteMsgPackLen writes the MessagePack marker byte typ followed by the
// length l encoded big-endian in byteN bytes (byteN must be 1, 2 or 4).
//
// It returns the number of bytes written and any error reported by w.
// Nothing is written, and (0, error) is returned, when byteN is not 1, 2 or 4
// or when l does not fit in byteN bytes (silently truncating it would
// produce a corrupt header).
func WriteMsgPackLen(w io.Writer, typ byte, l uint, byteN int) (int, error) {
	var maxLen uint64
	switch byteN {
	case 1:
		maxLen = 0xFF
	case 2:
		maxLen = 0xFFFF
	case 4:
		maxLen = 0xFFFFFFFF
	default:
		return 0, fmt.Errorf("over range length")
	}
	if uint64(l) > maxLen {
		return 0, fmt.Errorf("over range length: %d does not fit in %d byte(s)", l, byteN)
	}

	// Marker byte, then the length most-significant byte first.
	var hdr [5]byte
	hdr[0] = typ
	for i := byteN; i > 0; i-- {
		hdr[i] = byte(l)
		l >>= 8
	}
	return w.Write(hdr[:1+byteN])
}

// WriteMsgPackStr writes s to w as a MessagePack str value, picking the
// shortest encoding: fixstr (up to 31 bytes), str 8 (0xD9), str 16 (0xDA)
// or str 32 (0xDB). Length fields are big-endian.
//
// It returns the number of bytes actually written to w (header plus data)
// and the first error w reported; if writing the header fails, the data is
// not written. Strings longer than 0xFFFFFFFF bytes are rejected without
// writing anything.
func WriteMsgPackStr(w io.Writer, s string) (int, error) {
	l := len(s)
	var (
		hn  int
		err error
	)
	switch {
	case l <= 31:
		// fixstr: 0b101 in the top three bits (0xA0), length in the low five.
		hn, err = w.Write([]byte{0xA0 | byte(l)})
	case l <= 0xFF:
		hn, err = w.Write([]byte{msgpack_str_0xD9, byte(l)})
	case l <= 0xFFFF:
		hn, err = WriteMsgPackLen(w, msgpack_str_0xDA, uint(l), 2)
	case uint64(l) <= 0xFFFFFFFF:
		hn, err = WriteMsgPackLen(w, msgpack_str_0xDB, uint(l), 4)
	default:
		return 0, fmt.Errorf("over range for str, 0..0xFFFFFFFF")
	}
	if err != nil {
		return hn, err
	}
	// io.WriteString avoids copying s into a fresh []byte when w implements
	// io.StringWriter.
	dn, err := io.WriteString(w, s)
	return hn + dn, err
}

// WriteMsgPackBin writes s to w as a MessagePack bin value, picking the
// shortest encoding: bin 8 (0xC4), bin 16 (0xC5) or bin 32 (0xC6). Length
// fields are big-endian.
//
// It returns the number of bytes actually written to w (header plus payload)
// and the first error w reported; if writing the header fails, the payload is
// not written. Payloads longer than 0xFFFFFFFF bytes are rejected without
// writing anything.
func WriteMsgPackBin(w io.Writer, s []byte) (int, error) {
	l := len(s)
	var (
		hn  int
		err error
	)
	switch {
	case l <= 0xFF:
		hn, err = w.Write([]byte{msgpack_bin_0xC4, byte(l)})
	case l <= 0xFFFF:
		hn, err = WriteMsgPackLen(w, msgpack_bin_0xC5, uint(l), 2)
	case uint64(l) <= 0xFFFFFFFF:
		hn, err = WriteMsgPackLen(w, msgpack_bin_0xC6, uint(l), 4)
	default:
		return 0, fmt.Errorf("over range for bin, 0..0xFFFFFFFF")
	}
	if err != nil {
		return hn, err
	}
	dn, err := w.Write(s)
	return hn + dn, err
}

// MsgPackDecode is an incremental MessagePack decoder that is fed one byte
// at a time through InputByte.
//
// Only partially supported: just the str family (fixstr, str 8/16/32) and
// the bin family (bin 8/16/32) are recognized.
//
// Field notes:
//   - AssertType: msgpack_type_str (0x01) restricts InputByte to str headers
//     and msgpack_type_bin (0x02) to bin headers; any other value (including
//     the zero value) tries str first and then bin.
//   - msgtype: header marker of the value being decoded (msgpack_str_0xD5,
//     i.e. 0x05, for fixstr). NOTE(review): it is assigned but never read in
//     the code visible here — confirm whether anything else uses it.
//   - size: while decodestep is 1 it is the number of big-endian length
//     bytes to collect (1, 2 or 4); while decodestep is 2 it is the payload
//     length.
//   - cache: buffers the length bytes (step 1) and then the payload (step 2);
//     after a value completes it holds the payload returned by Bytes.
type MsgPackDecode struct {
	AssertType byte
	decodestep int8 // -1: idle/finished, 0: expecting header byte, 1: decoding length, 2: decoding data
	msgtype    byte
	size       int
	cache      gobase.BytesBuilder
}

// NewMsgPackDecode returns a decoder in its zero state. AssertType is unset,
// so both str and bin headers are accepted. The first byte passed to
// InputByte is parsed as a header.
func NewMsgPackDecode() *MsgPackDecode {
	return new(MsgPackDecode)
}

// Reset abandons any value currently being decoded. The buffered bytes are
// discarded lazily: the next InputByte call clears the buffer before storing
// its byte, and then treats that byte as a new header.
func (this *MsgPackDecode) Reset() {
	const idle int8 = -1
	this.decodestep = idle
}

// innerDecodeStrType interprets v as a str header and prepares the decoder
// for what follows it.
//
//   - fixstr (101XXXXX): the length is in the low five bits and the data
//     follows directly, so the decoder moves to decodestep 2. An empty fixstr
//     completes immediately: decodestep becomes -1 and the result is 1.
//   - str 8 / str 16 / str 32 (0xD9 / 0xDA / 0xDB): a 1-, 2- or 4-byte
//     big-endian length follows. size is set to that width and the decoder
//     moves to decodestep 1.
//
// On a match the cache is cleared and 0 (or 1 for an empty fixstr) is
// returned. Any other byte returns -1 and leaves the decoder untouched.
func (this *MsgPackDecode) innerDecodeStrType(v byte) int8 {
	if v>>5 == msgpack_str_0xD5 {
		this.msgtype = msgpack_str_0xD5
		this.size = int(v & 0x1F)
		this.decodestep = 2
		this.cache.Cleanup()
		if this.size != 0 {
			return 0
		}
		// Empty string: the value is already complete.
		this.decodestep = -1
		return 1
	}

	var lenBytes int
	switch v {
	case msgpack_str_0xD9:
		lenBytes = 1
	case msgpack_str_0xDA:
		lenBytes = 2
	case msgpack_str_0xDB:
		lenBytes = 4
	default:
		return -1
	}
	this.msgtype = v
	this.size = lenBytes
	this.decodestep = 1
	this.cache.Cleanup()
	return 0
}

// innerDecodeBinType interprets v as a bin header: bin 8 (0xC4), bin 16 (0xC5)
// or bin 32 (0xC6), each followed by a 1-, 2- or 4-byte big-endian length.
//
// On a match it does four things: records the marker in msgtype, sets size to
// the width of the length field, moves to decodestep 1, and clears the cache.
// It then returns 0. Any other byte returns -1 and leaves the decoder
// untouched.
func (this *MsgPackDecode) innerDecodeBinType(v byte) int8 {
	var lenBytes int
	switch v {
	case msgpack_bin_0xC4:
		lenBytes = 1
	case msgpack_bin_0xC5:
		lenBytes = 2
	case msgpack_bin_0xC6:
		lenBytes = 4
	default:
		return -1
	}
	this.msgtype = v
	this.size = lenBytes
	this.decodestep = 1
	this.cache.Cleanup()
	return 0
}

// Bytes returns the decoder's buffered bytes. Right after InputByte returns 1,
// this is the decoded payload; it stays valid until the next InputByte call,
// which clears the buffer. While a value is still in progress it holds the
// partial length field or partial payload.
// NOTE(review): this may alias gobase.BytesBuilder's internal storage — copy
// it if it must outlive the next InputByte call; confirm against gobase.
func (this *MsgPackDecode) Bytes() []byte {
	buf := this.cache.Bytes()
	return buf
}

// InputByte feeds the next byte of the stream into the decoder.
//
// Return values:
//   - 1: a complete str/bin value is buffered; its payload is available from
//     Bytes until the next call to InputByte clears the buffer.
//   - 0: the byte was consumed and more bytes are needed.
//   - -1: a header byte was expected but v is not one accepted under
//     AssertType. The decoder keeps waiting for a header, so the next byte is
//     tried as a header again. Rejected bytes stay in the buffer until a
//     valid header clears it.
func (this *MsgPackDecode) InputByte(v byte) int8 {
	// Starting a fresh value: drop whatever the previous value left behind.
	if this.decodestep == -1 {
		this.cache.Cleanup()
		this.decodestep = 0
	}
	this.cache.WriteByte(v)

	switch this.decodestep {
	case 0: // expecting a header byte
		switch this.AssertType {
		case msgpack_type_str:
			return this.innerDecodeStrType(v)
		case msgpack_type_bin:
			return this.innerDecodeBinType(v)
		}
		if r := this.innerDecodeStrType(v); r != -1 {
			return r
		}
		return this.innerDecodeBinType(v)

	case 1: // collecting the big-endian length field
		if this.cache.Len() != this.size {
			return 0
		}
		switch this.size {
		case 1:
			this.size = int(this.cache.Byte(0))
		case 2:
			this.size = int(this.cache.Uint16_BE(0))
		case 4:
			this.size = int(this.cache.Uint32_BE(0))
		default: // unsupported length-field width
			this.decodestep = -1
			return -1
		}
		this.cache.Cleanup()
		if this.size == 0 {
			// Zero-length value: complete without any payload bytes.
			this.decodestep = -1
			return 1
		}
		this.decodestep = 2
		return 0

	case 2: // collecting payload bytes
		if this.cache.Len() == this.size {
			this.decodestep = -1
			return 1
		}
	}
	return 0
}
